
GH-45506: [C++][Acero] More overflow-safe Swiss table #45515

Merged — 6 commits merged into apache:main on Feb 17, 2025

Conversation

@zanmato1984 (Contributor) commented on Feb 12, 2025

Rationale for this change

See #45506.

What changes are included in this PR?

  1. Abstract the current overflow-prone block data accesses into functions that do proper type promotion to avoid overflow, and remove the old block base address accessor (see the sketch below).
  2. Unify the data types used for the various concepts as they naturally are (i.e., without explicit promotion): uint32_t for block_id, int for num_xxx_bits/bytes, uint32_t for group_id, int for local_slot_id, and uint32_t for global_slot_id.
  3. Abstract several constants and utility functions for readability and maintainability.
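
For illustration, a minimal sketch of the kind of promoted accessor point 1 refers to (the name block_data and its parameters are placeholders, not necessarily the exact signature in the PR):

```cpp
// Hypothetical sketch: compute a block's address without 32-bit overflow.
// The uint32_t block id is promoted to uint64_t before the multiplication,
// so block offsets beyond 4 GiB are computed correctly.
inline const uint8_t* block_data(const uint8_t* blocks, uint32_t block_id,
                                 int num_block_bytes) {
  return blocks + static_cast<uint64_t>(block_id) *
                      static_cast<uint64_t>(num_block_bytes);
}
```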

Are these changes tested?

Existing tests should suffice.

It is really hard (gosh, I did try) to create a concrete test case that fails without this change and passes with it.

Are there any user-facing changes?

None.

@zanmato1984 (Contributor Author)

Most of this change is cleanup and refinement. @pitrou, mind taking a look? Thanks.

@github-actions bot added the 'awaiting review' label on Feb 12, 2025
@@ -643,37 +643,36 @@ void SwissTableMerge::MergePartition(SwissTable* target, const SwissTable* sourc
//
int source_group_id_bits =
SwissTable::num_groupid_bits_from_log_blocks(source->log_blocks());
uint64_t source_group_id_mask = ~0ULL >> (64 - source_group_id_bits);
Contributor Author

After cleaning up this unnecessary 64-bit arithmetic in this file, we can further clean up some temporary state, as mentioned in #45336 (comment).

@github-actions bot added the 'awaiting committer review' label and removed the 'awaiting review' label on Feb 12, 2025
@zanmato1984 force-pushed the more-overflow-safe-swiss-table branch from af1c470 to 1dd5c08 on February 12, 2025 12:53
@@ -81,18 +81,29 @@ class ARROW_EXPORT SwissTable {

void num_inserted(uint32_t i) { num_inserted_ = i; }

uint8_t* blocks() const { return blocks_->mutable_data(); }
Contributor Author

This is the source of all evil. Let's get rid of it!

Comment on lines 119 to 122
uint32_t group_id = *reinterpret_cast<const uint32_t*>(
block_data(block_id, num_block_bytes) + local_slots[id] * num_groupid_bytes +
bytes_status_in_block_);
group_id &= group_id_mask;
@pitrou (Member) commented on Feb 12, 2025

So we always issue a 32-bit load but then we optionally mask if the actual group id width is smaller? Don't we risk reading past block_data bounds here?

(also, should we use an unaligned load? see the SafeLoad and SafeLoadAs utility functions)

Contributor Author

> So we always issue a 32-bit load but then we optionally mask if the actual group id width is smaller? Don't we risk reading past block_data bounds here?

There will always be padding_ (64) extra bytes at the buffer end.

> (also, should we use an unaligned load? see the SafeLoad and SafeLoadAs utility functions)

It seems so indeed, though I didn't change how the original code does it.

I'll update later.
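
For reference, SafeLoad/SafeLoadAs-style utilities are, in essence, memcpy-based unaligned loads; a minimal sketch of the idea (not the exact Arrow signature):

```cpp
#include <cstdint>
#include <cstring>

// Sketch of a memcpy-based unaligned load: reading a uint32_t from an
// arbitrarily aligned pointer without undefined behaviour. Compilers lower
// the memcpy to a single (unaligned) load instruction on common targets.
inline uint32_t LoadU32Unaligned(const uint8_t* p) {
  uint32_t v;
  std::memcpy(&v, p, sizeof(v));
  return v;
}
```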

Contributor Author

> (also, should we use an unaligned load? see the SafeLoad and SafeLoadAs utility functions)
>
> It seems so indeed, though I didn't change how the original code does it.
>
> I'll update later.

OK, turns out I was wrong. The original code uses aligned read and my change made it unaligned. I'll need to update it with more care. Thank you for pointing this out.

Contributor Author

Changed back to the original aligned read with minor refinement.

Besides, I also did some more cleanup while fixing the alignment issue. See my latest commit. Thanks.

return static_cast<uint32_t>((1ULL << num_groupid_bits) - 1);
}

static constexpr int bytes_status_in_block_ = 8;
Member

Usually, compile-time constants should follow the naming convention kBytesStatusInBlock. Perhaps we can have a renaming pass in this PR or another one?

Contributor Author

Yes, it is supposed to be. However, I was following the naming convention of several existing compile-time constants in this class. I would like to change them all in another PR to keep this one solely focused on overflow prevention.

Comment on lines +387 to +389
uint32_t mask = num_groupid_bytes == 1 ? 0xFF
: num_groupid_bytes == 2 ? 0xFFFF
: 0xFFFFFFFF;
Member

Is there a reason for expanding the possible values instead of simply using the usual bitshift formula?

Contributor Author

Not particularly. This is just moving the original code.
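
For comparison, the usual bitshift formula mentioned by the reviewer would look roughly like this (a sketch; the 64-bit intermediate avoids UB when num_groupid_bytes == 4):

```cpp
// Equivalent mask computed with the usual bitshift formula: a group id spans
// num_groupid_bytes * 8 bits, so the mask is (1 << bits) - 1. The shift is
// done in 64 bits because shifting a 32-bit value by 32 would be UB.
uint32_t mask = static_cast<uint32_t>((1ULL << (num_groupid_bytes * 8)) - 1ULL);
```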

__m128i group_id_lo = _mm256_i64gather_epi32(elements, pos_lo, 1);
__m128i group_id_hi = _mm256_i64gather_epi32(elements, pos_hi, 1);
local_slot_lo =
_mm256_mul_epu32(local_slot_lo, _mm256_set1_epi32(num_groupid_bytes));
Member

This is using a 32-bit multiply even though local_slot_lo is supposed to be a vector of 64-bit ints? This might be correct because most bytes are zero, but I would at least expect an explanatory comment :)

Member

Ah, my bad, _mm256_mul_epu32 is actually a 64-bit multiply.
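
For clarity, _mm256_mul_epu32 (VPMULUDQ) multiplies the low 32 bits of each 64-bit lane of its operands and keeps the full 64-bit product; a scalar sketch of one lane:

```cpp
// Scalar equivalent of one 64-bit lane of _mm256_mul_epu32: the low 32 bits
// of each operand are multiplied and the full 64-bit product is kept.
uint64_t mul_epu32_lane(uint64_t a, uint64_t b) {
  return uint64_t{static_cast<uint32_t>(a)} * uint64_t{static_cast<uint32_t>(b)};
}
```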

@zanmato1984 force-pushed the more-overflow-safe-swiss-table branch from 1f9dde4 to aeb8b9d on February 13, 2025 14:40
uint64_t group_id_mask) const;
static uint32_t extract_group_id(const uint8_t* block_ptr, int local_slot,
int num_group_id_bits) {
// Extract group id using aligned 32-bit read.
Member

Why not use SafeLoad as in other places already?
(also, since this is non-trivial, factoring out the loading of a group id could go into a dedicated inline function)

Contributor Author

The original code uses an aligned read + masking, so I'm following it, possibly for performance's sake, I guess?

If SafeLoad is preferred (i.e., it doesn't hurt performance), then yes it is possible to factor this piece of code out.

Contributor Author

For the record, there are three places doing group id extraction:

  1. Here, extracting a single group id, publicly used by the swiss join: currently using an aligned read + masking;
  2. extract_group_ids, extracting a vector of group ids, used internally: using an aligned read without masking (the number of bits is a compile-time template parameter);
  3. grow_double, extracting a single group id inside a big loop, inlined: using an unaligned read + masking.

I think we should at least keep 2) as is because it makes perfect sense. 1) and 3) can be unified, either aligned or unaligned.
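
To make the comparison concrete, a rough sketch of the read + mask pattern used in 1) and 3) (identifiers are illustrative; shown here with an unaligned-safe memcpy load, whereas the actual code may use an aligned load instead):

```cpp
// Hypothetical sketch: read the 32-bit word starting at the group id's byte
// offset and mask off the bytes belonging to neighbouring slots. The wanted
// group id sits in the low bits on little-endian targets, and the trailing
// buffer padding keeps the 32-bit read in bounds for the last slot.
inline uint32_t ExtractGroupId(const uint8_t* block_ptr, int local_slot,
                               int num_group_id_bits) {
  constexpr int kBytesStatusInBlock = 8;  // status bytes precede the group ids
  const int num_group_id_bytes = num_group_id_bits / 8;
  const uint8_t* group_id_ptr =
      block_ptr + kBytesStatusInBlock + local_slot * num_group_id_bytes;
  uint32_t word;
  std::memcpy(&word, group_id_ptr, sizeof(word));
  const uint32_t mask = static_cast<uint32_t>((1ULL << num_group_id_bits) - 1ULL);
  return word & mask;
}
```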

What do you think?

Member

> I think we should at least keep 2) as is because it makes perfect sense. 1) and 3) can be unified, either aligned or unaligned.

Agreed. Feel free to choose whatever approach you prefer!

Contributor Author

Addressed. Thank you for the suggestion.

This PR is ready for review again.

@pitrou (Member) left a comment

Just two comments. I don't really know the internals of the Swiss table, so I just skipped some parts :)

int64_t block_id = hash >> (SwissTable::bits_hash_ - target->log_blocks());
int64_t block_id_mask = ((1LL << target->log_blocks()) - 1);
uint32_t block_id = SwissTable::block_id_from_hash(hash, target->log_blocks());
uint32_t block_id_mask = (1 << target->log_blocks()) - 1;
Member

Would a signed int shift be UB if target->log_blocks() is 31? Or does that not happen anyway?

@zanmato1984 (Contributor Author) commented on Feb 17, 2025

log_blocks() is guaranteed to be <= 29: the maximum number of rows in a Swiss table is 2^32 (we already have many guards on this), and each block contains 8 rows/slots, so there are at most 2^32 / 8 = 2^29 blocks. So the UB won't happen.
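
As a sanity check, this bound can be expressed as a small compile-time sketch (the constant names here are illustrative, not from the code base):

```cpp
// With at most 2^32 rows and 8 slots per block, log_blocks is at most 29,
// so (1 << log_blocks) is at most 2^29 and fits comfortably in a signed int.
constexpr int kMaxRowsLog2 = 32;      // maximum table size, guarded elsewhere
constexpr int kLogSlotsPerBlock = 3;  // 8 slots per block
constexpr int kMaxLogBlocks = kMaxRowsLog2 - kLogSlotsPerBlock;  // == 29
static_assert(kMaxLogBlocks < 31, "1 << log_blocks must not overflow int");
```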

uint64_t global_slot_id_mask = (1 << (log_blocks_ + 3)) - 1;
uint32_t SwissTable::wrap_global_slot_id(uint32_t global_slot_id) const {
uint32_t global_slot_id_mask =
static_cast<uint32_t>((1ULL << (log_blocks_ + kLogSlotsPerBlock)) - 1ULL);
Member

Interestingly we're still using a 64-bit unsigned shift here (see previous comment).

Contributor Author

However, the maximum number of slots/rows is 2^32, so log_blocks_ + kLogSlotsPerBlock can be as large as 32 and the shift must be done in 64 bits.
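
Concretely, at the maximum table size the exponent reaches 32, which is exactly where a 32-bit shift would become UB; a small sketch of the distinction:

```cpp
// log_blocks_ can be up to 29 and kLogSlotsPerBlock is 3, so the exponent
// can reach 32. Shifting a 32-bit value by 32 is UB (shift count equals the
// type width), while the 64-bit shift below is well defined:
//   (1ULL << 32) - 1ULL == 0xFFFFFFFF
uint32_t global_slot_id_mask =
    static_cast<uint32_t>((1ULL << (log_blocks_ + kLogSlotsPerBlock)) - 1ULL);
```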

Member

Ahah, ok. Thanks for the explanation!

@pitrou (Member) left a comment

Thanks a lot for the cleanups and improvements @zanmato1984 ! CI failures are unrelated.

@pitrou merged commit a53a77c into apache:main on Feb 17, 2025; 40 of 42 checks passed.
@pitrou removed the 'awaiting committer review' label on Feb 17, 2025

After merging your PR, Conbench analyzed the 4 benchmarking runs that have been run so far on merge-commit a53a77c.

There were no benchmark performance regressions. 🎉

The full Conbench report has more details. It also includes information about 8 possible false positives for unstable benchmarks that are known to sometimes produce them.
